{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "97a1c080",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from transformers import MarianMTModel, MarianModel, MarianTokenizer\n",
    "import torch\n",
    "from typing import List\n",
    "import logging\n",
    "\n",
    "import json\n",
    "import itertools\n",
    "import copy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1edafa89",
   "metadata": {},
   "source": [
    "Easy NMT code jsut for reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2b4a16a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "class OpusMT:\n",
    "    def __init__(self, easynmt_path: str = None, max_loaded_models: int = 10):\n",
    "        self.models = {}\n",
    "        self.max_loaded_models = max_loaded_models\n",
    "        self.max_length = None\n",
    "        \n",
    "    def load_model(self, model_name):\n",
    "        if model_name in self.models:\n",
    "            self.models[model_name]['last_loaded'] = time.time()\n",
    "            return self.models[model_name]['tokenizer'], self.models[model_name]['model']\n",
    "        else:\n",
    "            logger.info(\"Load model: \"+model_name)\n",
    "            tokenizer = MarianTokenizer.from_pretrained(model_name)\n",
    "            model = MarianMTModel.from_pretrained(model_name)\n",
    "            model.eval()\n",
    "\n",
    "            if len(self.models) >= self.max_loaded_models:\n",
    "                oldest_time = time.time()\n",
    "                oldest_model = None\n",
    "                for loaded_model_name in self.models:\n",
    "                    if self.models[loaded_model_name]['last_loaded'] <= oldest_time:\n",
    "                        oldest_model = loaded_model_name\n",
    "                        oldest_time = self.models[loaded_model_name]['last_loaded']\n",
    "                del self.models[oldest_model]\n",
    "\n",
    "            self.models[model_name] = {'tokenizer': tokenizer, 'model': model, 'last_loaded': time.time()}\n",
    "            return tokenizer, model\n",
    "\n",
    "    def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, **kwargs):\n",
    "        model_name = 'Helsinki-NLP/opus-mt-{}-{}'.format(source_lang, target_lang)\n",
    "        tokenizer, model = self.load_model(model_name)\n",
    "        model.to(device)\n",
    "\n",
    "        inputs = tokenizer(sentences, truncation=True, padding=True, max_length=self.max_length, return_tensors=\"pt\")\n",
    "\n",
    "        for key in inputs:\n",
    "            inputs[key] = inputs[key].to(device)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            translated = model.generate(**inputs, num_beams=beam_size, **kwargs)\n",
    "            output = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]\n",
    "\n",
    "        return output\n",
    "\n",
    "    def save(self, output_path):\n",
    "        return {\"max_loaded_models\": self.max_loaded_models}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07bb80c6",
   "metadata": {},
   "source": [
    "## Load translation system"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "66ca7e6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "src_lang = \"en\"\n",
    "tgt_lang = \"de\"\n",
    "translator = \"opus-mt\"\n",
    "\n",
    "model_name = f\"Helsinki-NLP/{translator}-{src_lang}-{tgt_lang}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "fc9cf91d",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer_org = MarianTokenizer.from_pretrained(model_name)\n",
    "model_org = MarianMTModel.from_pretrained(model_name)\n",
    "\n",
    "tokenizer = copy.deepcopy(tokenizer_org)\n",
    "model = copy.deepcopy(model_org)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9139907a",
   "metadata": {},
   "source": [
    "## Load variants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "71876eb3",
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_fn = f\"../data/wino_mt/{tgt_lang}_variants.json\"\n",
    "with open(variant_fn, 'r') as var_json:\n",
    "    variants = json.load(var_json)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a15a52cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "variant_combined = list(itertools.chain.from_iterable(variants.values()))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d96ffa7e",
   "metadata": {},
   "source": [
    "## Add translation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "15cae802",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def add_tgt_words(words, tokenizer, model):\n",
    "    \n",
    "    tok_ids_per_word = {}\n",
    "    with tokenizer.as_target_tokenizer():\n",
    "        for word in words:\n",
    "            token_ids = tokenizer(word)[\"input_ids\"][:-1]\n",
    "            if len(token_ids) == 1:\n",
    "                print(f\"Word: {word} already in the vocabulary\")\n",
    "            else:\n",
    "                tok_ids_per_word[word] = token_ids\n",
    "\n",
    "    num_added = tokenizer.add_tokens(list(tok_ids_per_word.keys()))\n",
    "    sorted_words_added = list(sorted(tokenizer.added_tokens_encoder, key=tokenizer.added_tokens_encoder.get))\n",
    "    assert num_added == len(sorted_words_added)\n",
    "    \n",
    "    model.resize_token_embeddings(len(tokenizer))\n",
    "    \n",
    "    lm_head = model.get_output_embeddings()\n",
    "    flb = model.final_logits_bias\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        avg_embs = []\n",
    "        avg_flbs = []\n",
    "        for word in sorted_words_added:\n",
    "            tok_ids = tok_ids_per_word[word]\n",
    "            tok_weights = lm_head.weight[tok_ids,:]\n",
    "            #tok_flb = flb[:,tok_ids]\n",
    "\n",
    "            weight_mean = torch.mean(tok_weights, axis=0, keepdim=True)\n",
    "            #flb_mean = torch.mean(tok_flb, axis=1, keepdim=True)\n",
    "            \n",
    "            avg_embs.append(weight_mean)\n",
    "            #avg_flbs.append(flb_mean)\n",
    "\n",
    "        lm_head.weight[-num_added:,:] = torch.vstack(avg_embs).requires_grad_()\n",
    "        #flb[:,-num_added:] = torch.hstack(avg_flbs).requires_grad_()\n",
    "    \n",
    "    model.set_output_embeddings(lm_head)\n",
    "    #model.final_logits_bias =flb\n",
    "    #model.register_buffer(\"final_logits_bias\", flb)\n",
    "    \n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "96492e65",
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_checks(text, tokenizer, model):\n",
    "    with tokenizer.as_target_tokenizer():\n",
    "        print(tokenizer.tokenize(text))\n",
    "        print(tokenizer(text))\n",
    "        print(tokenizer.decode(tokenizer.encode(text)))\n",
    "        assert text == tokenizer.decode(tokenizer.encode(text))\n",
    "    print(tokenizer.tokenize(text))\n",
    "    print(tokenizer(text))\n",
    "    model.get_decoder().embed_tokens.weight.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8fbd0845",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['▁Sie', '▁ist', '▁die', '▁Berater', 'in', '.']\n",
      "{'input_ids': [42, 29, 11, 13992, 118, 3, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}\n",
      "Sie ist die Beraterin.\n",
      "['▁Sie', '▁is', 't', '▁die', '▁Be', 'rate', 'rin', '.']\n",
      "{'input_ids': [42, 19, 46, 11, 312, 2916, 5546, 3, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}\n",
      "['▁Er', '▁ist', '▁der', '▁Krankenpflege', 'r', '.']\n",
      "{'input_ids': [201, 29, 9, 46878, 104, 3, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}\n",
      "Er ist der Krankenpfleger.\n",
      "['▁Er', '▁is', 't', '▁der', '▁K', 'rank', 'en', 'pf', 'leg', 'er', '.']\n",
      "{'input_ids': [201, 19, 46, 9, 272, 10415, 15, 3226, 5093, 45, 3, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/lnet/work/people/limisiewicz/mt-tokenizer-bias/.virtualenv/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:3542: UserWarning: `as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your labels by using the argument `text_target` of the regular `__call__` method (either in the same call as your input texts if you use the same keyword arguments, or in a separate call.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "print_checks('Sie ist die Beraterin.', tokenizer, model)\n",
    "print_checks('Er ist der Krankenpfleger.', tokenizer, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "8daabe21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Word: Berater already in the vocabulary\n",
      "Word: Designer already in the vocabulary\n",
      "Word: Bauer already in the vocabulary\n",
      "Word: Schreiner already in the vocabulary\n",
      "Word: Chef already in the vocabulary\n",
      "Word: Koch already in the vocabulary\n",
      "Word: Sheriff already in the vocabulary\n",
      "Word: Sheriff already in the vocabulary\n",
      "Word: Manager already in the vocabulary\n",
      "Word: Chef already in the vocabulary\n",
      "Word: Schneider already in the vocabulary\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/lnet/work/people/limisiewicz/mt-tokenizer-bias/.virtualenv/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx (Triggered internally at  /pytorch/c10/cuda/CUDAFunctions.cpp:100.)\n",
      "  return torch._C._cuda_getDeviceCount() > 0\n"
     ]
    }
   ],
   "source": [
    "add_tgt_words(variant_combined, tokenizer, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "751a74b9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['▁Sie', '▁ist', '▁die', 'Beraterin', '▁', '.']\n",
      "{'input_ids': [42, 29, 11, 58101, 17, 3, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}\n",
      "Sie ist die Beraterin ▁.\n"
     ]
    },
    {
     "ename": "AssertionError",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)",
      "Cell \u001b[0;32mIn [11], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m print_checks(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mSie ist die Beraterin.\u001b[39m\u001b[38;5;124m'\u001b[39m, tokenizer, model)\n\u001b[1;32m      2\u001b[0m print_checks(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mEr ist der Krankenpfleger.\u001b[39m\u001b[38;5;124m'\u001b[39m, tokenizer, model)\n",
      "Cell \u001b[0;32mIn [8], line 6\u001b[0m, in \u001b[0;36mprint_checks\u001b[0;34m(text, tokenizer, model)\u001b[0m\n\u001b[1;32m      4\u001b[0m     \u001b[38;5;28mprint\u001b[39m(tokenizer(text))\n\u001b[1;32m      5\u001b[0m     \u001b[38;5;28mprint\u001b[39m(tokenizer\u001b[38;5;241m.\u001b[39mdecode(tokenizer\u001b[38;5;241m.\u001b[39mencode(text)))\n\u001b[0;32m----> 6\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m text \u001b[38;5;241m==\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mdecode(tokenizer\u001b[38;5;241m.\u001b[39mencode(text))\n\u001b[1;32m      7\u001b[0m \u001b[38;5;28mprint\u001b[39m(tokenizer\u001b[38;5;241m.\u001b[39mtokenize(text))\n\u001b[1;32m      8\u001b[0m \u001b[38;5;28mprint\u001b[39m(tokenizer(text))\n",
      "\u001b[0;31mAssertionError\u001b[0m: "
     ]
    }
   ],
   "source": [
    "print_checks('Sie ist die Beraterin.', tokenizer, model)\n",
    "print_checks('Er ist der Krankenpfleger.', tokenizer, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "0cb381b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-4.9961e-04, -2.5946e-03, -1.1566e-02, -1.5693e-03,  1.6126e-02,\n",
       "         9.6478e-03,  5.8083e-03, -9.0018e-03,  1.0374e-02,  7.3548e-03,\n",
       "         3.3247e-02, -1.2239e-02,  6.2916e-03,  6.4396e-03,  1.2991e-02,\n",
       "        -5.0102e-03,  1.4417e-02,  1.9862e-02,  6.4506e-03,  3.1815e-02,\n",
       "         7.2835e-02, -1.2224e-02,  3.4857e-02,  3.3399e-02, -1.5439e-02,\n",
       "        -1.0619e-04,  2.8706e-02,  4.9822e-02, -2.2799e-02, -8.4528e-04,\n",
       "        -1.3515e-02,  6.3282e-04,  3.2165e-02,  3.0161e-02, -1.2378e-02,\n",
       "        -1.1573e-02, -1.3590e-02,  2.2442e-03, -5.6726e-03,  1.1327e-02,\n",
       "        -1.5868e-02, -1.3372e-02, -1.2550e-02,  3.2997e-06,  2.3603e-02,\n",
       "         2.4135e-03, -2.0603e-02, -1.6795e-02, -1.5527e-02,  9.4775e-03,\n",
       "        -4.3555e-03, -3.1154e-02, -1.4124e-03, -6.2352e-03, -2.3589e-02,\n",
       "        -9.4236e-03, -2.6185e-02, -3.8034e-03,  4.0895e-02,  6.4161e-03,\n",
       "        -3.9768e-03, -2.1842e-02, -1.7886e-02,  2.0028e-03, -2.1320e-02,\n",
       "        -3.9784e-02,  1.6040e-02,  4.4598e-02, -3.3903e-02,  1.2445e-02,\n",
       "        -5.7379e-03, -3.4057e-03, -3.3743e-03, -9.0653e-03, -1.0858e-02,\n",
       "         4.1961e-03, -5.6486e-03, -1.3761e-02, -7.8636e-03, -2.0611e-02,\n",
       "        -1.3218e-02,  1.0915e-02, -2.1768e-02, -1.0004e-02, -3.1206e-02,\n",
       "        -6.7249e-03, -1.5017e-02, -1.9687e-02, -2.7622e-02, -2.9914e-02,\n",
       "        -3.0721e-02, -3.0852e-03, -1.4575e-02, -9.2557e-03, -2.1984e-02,\n",
       "        -1.0537e-02,  1.0076e-02,  1.6214e-02,  3.1993e-04, -1.8340e-02,\n",
       "        -1.3804e-02, -7.4661e-04, -1.6584e-02, -1.3652e-03,  1.0778e-02,\n",
       "         1.1277e-02, -1.5933e-02, -2.8914e-03,  4.7226e-03,  2.2861e-03,\n",
       "        -2.0842e-03, -3.8060e-03,  1.0799e-02,  2.5564e-03,  2.0616e-02,\n",
       "        -1.2079e-02, -3.3588e-03,  1.9120e-04, -5.0502e-03, -3.6245e-03,\n",
       "        -4.1621e-04, -2.2004e-02, -1.6960e-02,  4.9574e-03,  1.2855e-02,\n",
       "        -2.3796e-03, -5.3972e-02, -4.9404e-03,  4.8336e-03, -1.5273e-02,\n",
       "         7.8941e-03, -1.9277e-02, -8.7333e-03,  3.8531e-03,  2.9518e-03,\n",
       "        -6.0623e-03,  1.2903e-02, -2.3940e-02, -5.6395e-03,  7.2676e-04,\n",
       "        -3.5958e-03, -2.4581e-02, -1.4027e-02, -3.3579e-03,  8.1456e-03,\n",
       "        -3.6141e-02, -3.9644e-03, -1.3777e-02, -3.6844e-02, -1.0863e-02,\n",
       "         1.7333e-02, -7.1425e-03, -9.3720e-03, -2.5619e-02, -3.3999e-02,\n",
       "         3.5053e-02,  5.2774e-03, -4.8484e-03, -2.4193e-03, -3.7080e-02,\n",
       "        -2.7884e-03, -1.0488e-03, -2.2274e-03, -9.4150e-03,  1.5378e-03,\n",
       "         1.1250e-02, -6.9945e-03,  7.0629e-03,  2.5976e-03, -2.6135e-02,\n",
       "        -8.5851e-03, -1.1475e-02,  1.8802e-02, -2.3589e-03, -1.6489e-02,\n",
       "        -1.3024e-02,  3.9979e-02, -9.8119e-03, -5.1097e-03, -1.6076e-02,\n",
       "         3.5877e-04,  1.3735e-02,  5.5541e-03, -4.3506e-03, -3.7710e-02,\n",
       "        -1.3807e-02, -1.7530e-02,  1.8673e-02, -1.3783e-02, -2.4507e-02,\n",
       "        -7.0968e-03,  6.2168e-03,  4.5708e-03, -3.0702e-03,  2.2133e-03,\n",
       "        -1.1383e-02, -1.5659e-02, -1.8308e-02, -1.7900e-02, -1.6897e-03,\n",
       "        -2.8481e-02, -6.6487e-03, -7.0665e-03, -5.3711e-03, -4.7734e-03,\n",
       "        -3.6272e-03,  2.0758e-02,  1.1869e-02,  2.4010e-02,  2.1058e-02,\n",
       "        -1.0241e-02, -4.1639e-03, -1.3508e-02, -2.5812e-03,  1.1614e-03,\n",
       "         7.7520e-03,  7.4675e-03, -2.6673e-02,  6.6996e-03,  1.7540e-02,\n",
       "        -1.5208e-03, -2.7078e-02,  2.6665e-03,  5.4627e-03, -1.4761e-03,\n",
       "         8.6480e-03, -1.2402e-02, -1.0131e-02,  1.8860e-02,  7.7030e-03,\n",
       "         5.3069e-03, -2.6239e-02, -5.2177e-03,  1.8002e-03, -6.2017e-03,\n",
       "        -1.3369e-02,  3.9109e-03, -1.7171e-02, -1.3447e-02,  2.1910e-02,\n",
       "         2.0895e-02, -4.6433e-03,  5.0896e-04, -8.2150e-03,  6.5379e-03,\n",
       "        -9.1466e-04, -2.0576e-02, -9.0151e-03,  3.4023e-03, -1.6893e-02,\n",
       "        -1.7913e-02,  2.6940e-02, -2.8209e-02,  1.4674e-02, -1.4662e-02,\n",
       "        -1.0403e-02,  1.1137e-02,  2.8106e-02, -7.7015e-03,  4.2040e-02,\n",
       "        -8.2333e-03,  1.3679e-02, -1.5709e-02, -1.1498e-02, -1.3232e-02,\n",
       "        -2.9875e-02,  3.8423e-03,  2.7268e-02, -2.4616e-02,  1.9017e-02,\n",
       "        -2.2488e-02, -1.2887e-02, -5.3952e-03, -4.8952e-02,  4.4686e-02,\n",
       "        -2.2898e-02,  1.6652e-02,  6.2129e-04, -1.2982e-02, -6.4991e-03,\n",
       "        -2.4033e-02,  3.4585e-02,  1.1009e-02,  2.6408e-03, -8.3457e-03,\n",
       "        -7.4153e-03, -1.0289e-02, -2.9878e-02,  5.4368e-03,  1.6291e-02,\n",
       "        -4.5594e-02,  3.7910e-02, -2.4515e-02, -1.8129e-02, -1.7882e-03,\n",
       "         3.2474e-02, -2.3766e-02,  1.2347e-02,  7.3065e-03,  1.9730e-02,\n",
       "        -5.7589e-03,  1.7299e-03, -1.8329e-02, -3.4592e-02, -8.2316e-03,\n",
       "         9.9227e-03, -1.1912e-02, -5.8401e-03, -2.8427e-02,  1.8724e-03,\n",
       "        -2.1221e-02, -1.6832e-02, -4.8195e-03,  4.3867e-02, -2.9250e-02,\n",
       "        -2.3411e-02, -8.8610e-03,  7.3494e-04, -1.3210e-02, -1.6543e-02,\n",
       "         7.6741e-02, -7.2518e-02,  5.7757e-03,  3.3843e-02, -1.7061e-02,\n",
       "        -1.8177e-02, -3.3644e-02, -9.9560e-03, -2.4230e-02, -4.1117e-02,\n",
       "        -5.6301e-03, -4.3846e-03,  9.2480e-03, -1.3176e-02,  6.5746e-03,\n",
       "        -4.7491e-02, -1.4755e-02, -2.2438e-02,  1.8055e-02, -2.1328e-02,\n",
       "        -4.7992e-02,  2.0434e-02, -3.9570e-02, -3.0522e-02, -4.3974e-02,\n",
       "        -3.4812e-02, -2.5106e-02, -4.5738e-02, -9.5249e-03, -3.3933e-02,\n",
       "         2.3357e-02, -1.1092e-02,  2.1366e-03, -3.4053e-02,  2.9607e-02,\n",
       "        -1.7378e-02, -4.6291e-02,  5.1066e-03, -4.2862e-02, -1.4337e-02,\n",
       "        -2.7158e-02,  1.5152e-02, -2.1863e-02, -2.6661e-02, -4.3030e-03,\n",
       "        -2.5150e-02, -3.8047e-03, -1.3565e-02, -3.7043e-02, -3.3731e-02,\n",
       "        -6.1555e-02, -3.4274e-02,  1.7002e-01,  1.7714e-02,  6.7578e-03,\n",
       "         8.0082e-03,  6.6925e-02, -6.1003e-02, -4.6243e-02, -1.6947e-02,\n",
       "        -4.1261e-02, -3.5614e-02, -1.0150e-02, -3.9676e-02, -2.7505e-02,\n",
       "        -3.0869e-02, -1.1547e-03, -2.7618e-03, -1.2521e-02, -2.9101e-02,\n",
       "        -3.7893e-02, -2.3973e-02,  7.1226e-03, -1.0992e-02,  3.4772e-02,\n",
       "        -2.7085e-02, -3.0892e-02,  3.5404e-03, -2.6868e-02, -3.7659e-02,\n",
       "        -1.5513e-02,  9.2479e-02, -2.7900e-02, -1.5843e-02,  3.7145e-02,\n",
       "        -1.3757e-03,  4.8572e-02, -5.9145e-02, -3.7593e-02,  1.6195e-01,\n",
       "         2.3615e-02,  4.0211e-02, -2.8999e-02, -1.0540e-02, -6.0386e-02,\n",
       "         5.8295e-03,  1.9155e-02, -1.3605e-02, -2.0904e-02, -8.9514e-04,\n",
       "        -2.9715e-02, -1.2524e-03,  5.1993e-03,  5.1985e-03, -2.8984e-02,\n",
       "        -1.5259e-02, -7.0637e-03, -1.5160e-02,  3.6014e-02,  2.5863e-04,\n",
       "        -3.8742e-02, -6.4885e-02, -1.5045e-02, -2.4017e-02, -3.6643e-02,\n",
       "        -4.4460e-02, -2.6262e-02, -1.8581e-02, -3.8520e-02, -3.1863e-02,\n",
       "        -2.0749e-02,  7.9416e-03, -2.4690e-03, -7.0786e-02, -3.2070e-02,\n",
       "        -3.4809e-02, -4.6171e-03, -1.1352e-02, -2.1722e-02, -3.9799e-02,\n",
       "        -2.7855e-02, -1.8627e-02, -1.2367e-02, -9.2940e-03, -5.6731e-02,\n",
       "        -6.8204e-02, -1.4637e-02, -1.4937e-02, -1.9978e-02, -4.2208e-02,\n",
       "        -9.4172e-03, -3.0187e-02,  7.3313e-03, -9.8530e-03,  8.4848e-03,\n",
       "        -4.2320e-02, -2.9218e-02, -3.8647e-02, -2.2840e-02,  3.8349e-02,\n",
       "        -1.8885e-02, -1.1306e-03,  1.4164e-03, -5.9394e-02,  2.5930e-02,\n",
       "        -3.3881e-03, -1.8096e-02, -3.1272e-02, -3.2212e-02, -5.4199e-03,\n",
       "        -1.0153e-02, -4.2566e-02, -1.5548e-02, -3.0725e-02, -3.0498e-02,\n",
       "        -2.3532e-02, -2.2618e-02, -2.6573e-02, -4.1830e-02, -5.1136e-02,\n",
       "        -4.2789e-02, -7.0537e-03, -4.0602e-02, -5.0724e-02, -1.0609e-02,\n",
       "        -1.8783e-02,  1.4563e-02,  1.9422e-02, -1.1749e-02, -5.9481e-03,\n",
       "        -1.7652e-02, -5.1215e-02, -9.1677e-03, -1.5546e-02, -2.7162e-02,\n",
       "         8.8563e-02,  1.6973e-03, -9.5256e-04, -1.5241e-02, -4.2293e-02,\n",
       "         7.2772e-03,  2.8469e-03], grad_fn=<SelectBackward0>)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_output_embeddings().weight[58100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d3b5d344",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-3.2580, -0.0779]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.final_logits_bias[:,[13992, 118]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4b7e8346",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-3.0435]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.final_logits_bias[:,[58102]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "369f525a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-0.0332,  0.0274, -0.0358,  ..., -0.0474, -0.0416, -0.0006],\n",
       "        [-0.0269,  0.0070, -0.0027,  ..., -0.1027,  0.0039,  0.0110]],\n",
       "       grad_fn=<IndexBackward0>)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_org.get_output_embeddings().weight[[13992, 118],:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d4b4a409",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.0125,  0.0015, -0.0084,  0.0196,  0.0175,  0.0168,  0.0046, -0.0067,\n",
       "         0.0088,  0.0009,  0.0230, -0.0215,  0.0118,  0.0080,  0.0300,  0.0043,\n",
       "         0.0033,  0.0356, -0.0004,  0.0417,  0.0824,  0.0016,  0.0275,  0.0359,\n",
       "        -0.0080,  0.0004,  0.0401,  0.0532, -0.0165,  0.0009, -0.0004, -0.0025,\n",
       "         0.0385,  0.0316, -0.0097,  0.0034,  0.0195, -0.0054, -0.0074,  0.0102,\n",
       "        -0.0034, -0.0112, -0.0078, -0.0077,  0.0256,  0.0048,  0.0033, -0.0015,\n",
       "        -0.0105,  0.0116, -0.0018, -0.0076,  0.0010,  0.0042, -0.0245,  0.0012,\n",
       "        -0.0099, -0.0058,  0.0486,  0.0156, -0.0046, -0.0176, -0.0016, -0.0056,\n",
       "        -0.0143, -0.0139,  0.0056,  0.0611, -0.0376,  0.0071,  0.0144, -0.0059,\n",
       "        -0.0018, -0.0185, -0.0120,  0.0032, -0.0049, -0.0005, -0.0016, -0.0146,\n",
       "        -0.0190,  0.0112,  0.0031,  0.0173, -0.0254,  0.0055, -0.0135, -0.0052,\n",
       "        -0.0104,  0.0051, -0.0163, -0.0107, -0.0165,  0.0224, -0.0180, -0.0081,\n",
       "         0.0049,  0.0085, -0.0306, -0.0053, -0.0253,  0.0036, -0.0211,  0.0113,\n",
       "         0.0047, -0.0062, -0.0014, -0.0070, -0.0188, -0.0084,  0.0015, -0.0004,\n",
       "         0.0183,  0.0006,  0.0191,  0.0172, -0.0038,  0.0046,  0.0097, -0.0101,\n",
       "        -0.0053, -0.0178, -0.0108,  0.0196,  0.0155, -0.0033, -0.0378, -0.0019,\n",
       "        -0.0081, -0.0036,  0.0024, -0.0039, -0.0100, -0.0216,  0.0018,  0.0038,\n",
       "        -0.0055, -0.0112, -0.0111,  0.0095, -0.0067, -0.0032, -0.0122, -0.0418,\n",
       "         0.0091, -0.0004,  0.0227, -0.0113, -0.0193,  0.0184,  0.0103, -0.0054,\n",
       "        -0.0032, -0.0169, -0.0053,  0.0311,  0.0069,  0.0146,  0.0073, -0.0046,\n",
       "        -0.0161,  0.0026, -0.0217,  0.0260,  0.0112, -0.0034,  0.0032,  0.0226,\n",
       "        -0.0178, -0.0038,  0.0337, -0.0030,  0.0160, -0.0068, -0.0030,  0.0088,\n",
       "         0.0390, -0.0019, -0.0061, -0.0213,  0.0098,  0.0019,  0.0118,  0.0057,\n",
       "         0.0025, -0.0078, -0.0062, -0.0064, -0.0088, -0.0105,  0.0071,  0.0028,\n",
       "         0.0293,  0.0017,  0.0222, -0.0024,  0.0004, -0.0130,  0.0178,  0.0100,\n",
       "        -0.0070,  0.0182,  0.0053, -0.0153,  0.0074,  0.0161,  0.0311, -0.0095,\n",
       "        -0.0004,  0.0127,  0.0092, -0.0041,  0.0176,  0.0029,  0.0157,  0.0210,\n",
       "         0.0059, -0.0374,  0.0142,  0.0205,  0.0082, -0.0002,  0.0118,  0.0032,\n",
       "         0.0223,  0.0054, -0.0067,  0.0087,  0.0164,  0.0043,  0.0338, -0.0170,\n",
       "         0.0137,  0.0357,  0.0134, -0.0207,  0.0152, -0.0020, -0.0065,  0.0364,\n",
       "         0.0036,  0.0231, -0.0132,  0.0103, -0.0040,  0.0110, -0.0207,  0.0028,\n",
       "         0.0173, -0.0117, -0.0099,  0.0278,  0.0071,  0.0163,  0.0145, -0.0129,\n",
       "         0.0274,  0.0080,  0.0045,  0.0415, -0.0037,  0.0041, -0.0303, -0.0114,\n",
       "        -0.0090, -0.0314,  0.0180,  0.0315, -0.0170,  0.0070, -0.0070,  0.0118,\n",
       "        -0.0091, -0.0342,  0.0422,  0.0018,  0.0376, -0.0103,  0.0059, -0.0021,\n",
       "        -0.0178,  0.0376,  0.0334,  0.0273, -0.0044, -0.0030, -0.0146, -0.0099,\n",
       "         0.0371,  0.0004, -0.0301,  0.0527, -0.0026, -0.0169,  0.0023,  0.0430,\n",
       "        -0.0285,  0.0070,  0.0037,  0.0289,  0.0057, -0.0021, -0.0020, -0.0125,\n",
       "        -0.0112,  0.0115, -0.0024,  0.0207, -0.0166, -0.0028, -0.0034, -0.0179,\n",
       "        -0.0096,  0.0425, -0.0216, -0.0185,  0.0089,  0.0159, -0.0002, -0.0289,\n",
       "         0.0953, -0.0276,  0.0061,  0.0386, -0.0105, -0.0089, -0.0376, -0.0074,\n",
       "        -0.0239, -0.0156, -0.0099, -0.0018,  0.0179, -0.0104,  0.0009, -0.0259,\n",
       "        -0.0038, -0.0209,  0.0138, -0.0104, -0.0497,  0.0270, -0.0346, -0.0166,\n",
       "        -0.0390, -0.0622, -0.0108, -0.0416, -0.0110, -0.0363,  0.0320,  0.0136,\n",
       "        -0.0002, -0.0295,  0.0262, -0.0092, -0.0178,  0.0019, -0.0343, -0.0041,\n",
       "        -0.0226,  0.0130, -0.0173, -0.0300, -0.0007, -0.0228, -0.0022,  0.0016,\n",
       "        -0.0491, -0.0325, -0.0374, -0.0306,  0.1692,  0.0406,  0.0228,  0.0172,\n",
       "         0.0632, -0.0496, -0.0304,  0.0160, -0.0175, -0.0360,  0.0014, -0.0291,\n",
       "        -0.0301, -0.0476,  0.0075, -0.0180, -0.0125, -0.0241, -0.0225, -0.0295,\n",
       "         0.0202, -0.0173,  0.0331, -0.0265, -0.0443,  0.0069, -0.0330, -0.0309,\n",
       "        -0.0138,  0.1208, -0.0105, -0.0055,  0.0115, -0.0036,  0.0580, -0.0525,\n",
       "        -0.0407,  0.1631,  0.0229,  0.0401, -0.0260, -0.0124, -0.0517,  0.0151,\n",
       "         0.0451,  0.0012, -0.0187, -0.0199, -0.0054, -0.0077,  0.0035,  0.0171,\n",
       "        -0.0227, -0.0139, -0.0026, -0.0031,  0.0334, -0.0144, -0.0370, -0.0298,\n",
       "        -0.0159, -0.0087, -0.0290, -0.0255, -0.0214, -0.0124, -0.0189, -0.0203,\n",
       "        -0.0165,  0.0060, -0.0021, -0.0471,  0.0015, -0.0393, -0.0056, -0.0099,\n",
       "        -0.0106, -0.0393, -0.0294, -0.0332, -0.0182, -0.0236, -0.0419, -0.0412,\n",
       "        -0.0141, -0.0235, -0.0063, -0.0574, -0.0221, -0.0215,  0.0107,  0.0091,\n",
       "        -0.0104, -0.0305, -0.0204, -0.0189, -0.0307,  0.0301, -0.0046, -0.0153,\n",
       "        -0.0030, -0.0293,  0.0298, -0.0123, -0.0113, -0.0345, -0.0178, -0.0120,\n",
       "        -0.0196, -0.0096, -0.0080, -0.0230, -0.0459, -0.0032, -0.0198, -0.0067,\n",
       "        -0.0112, -0.0382, -0.0466, -0.0015, -0.0356, -0.0455, -0.0039, -0.0309,\n",
       "         0.0156,  0.0341, -0.0005,  0.0011, -0.0235, -0.0429, -0.0105, -0.0170,\n",
       "        -0.0304,  0.1026,  0.0032,  0.0018, -0.0010, -0.0378, -0.0199,  0.0011],\n",
       "       grad_fn=<SliceBackward0>)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.get_output_embeddings().weight[58101,:]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc4ecdfb",
   "metadata": {},
   "source": [
    "## Saving modified model and tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "80389282",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('../models/tokenizer/avg_emb_opus-mt-en-he/tokenizer_config.json',\n",
       " '../models/tokenizer/avg_emb_opus-mt-en-he/special_tokens_map.json',\n",
       " '../models/tokenizer/avg_emb_opus-mt-en-he/vocab.json',\n",
       " '../models/tokenizer/avg_emb_opus-mt-en-he/source.spm',\n",
       " '../models/tokenizer/avg_emb_opus-mt-en-he/target.spm',\n",
       " '../models/tokenizer/avg_emb_opus-mt-en-he/added_tokens.json')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.save_pretrained(f\"../models/model/avg_emb_{translator}-{src_lang}-{tgt_lang}\")\n",
    "tokenizer.save_pretrained(f\"../models/tokenizer/avg_emb_{translator}-{src_lang}-{tgt_lang}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "24302cbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = MarianMTModel.from_pretrained(f\"../models/model/avg_emb_{translator}-{src_lang}-{tgt_lang}/\")\n",
    "tokenizer = MarianTokenizer.from_pretrained(f\"../models/tokenizer/avg_emb_{translator}-{src_lang}-{tgt_lang}/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31446dc7",
   "metadata": {},
   "source": [
    "## Translating test sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "12e4e7bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "def translate_sentences(sentences, model, tokenizer, beam_size = 5, **kwargs):\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    \n",
    "\n",
    "    model.to(device)\n",
    "\n",
    "    inputs = tokenizer(sentences, truncation=True, padding=True, max_length=None, return_tensors=\"pt\")\n",
    "    # print(inputs)\n",
    "    for key in inputs:\n",
    "        inputs[key] = inputs[key].to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        translated = model.generate(**inputs, num_beams=beam_size, **kwargs)\n",
    "        # print(translated)\n",
    "        output = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]\n",
    "\n",
    "    return output, translated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ce6ce254",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['The developer argued with the designer because she did not like the design.',\n",
       " 'The developer argued with the designer because his idea cannot be implemented.',\n",
       " 'The mechanic gave the clerk a present because she won the lottery.',\n",
       " 'The mechanic gave the clerk a present because it was his birthday.',\n",
       " 'The mover said thank you to the housekeeper because she is grateful.',\n",
       " 'The mover said thank you to the housekeeper because he cleaned the truck.',\n",
       " 'The assistant scanned documents for the analyst because she asked.',\n",
       " 'The assistant scanned documents for the analyst because he knows how to use the scan machine.',\n",
       " 'The chief gave the housekeeper a tip because she was satisfied.',\n",
       " 'The chief gave the housekeeper a tip because he was helpful.',\n",
       " 'The salesperson sold some books to the librarian because it was her job.',\n",
       " 'The salesperson sold some books to the librarian because he wanted to learn.',\n",
       " 'The lawyer yelled at the hairdresser because she was mad.',\n",
       " 'The lawyer yelled at the hairdresser because he did a bad job.',\n",
       " 'The cook prepared a dish for the teacher because she just learned a new dish.',\n",
       " 'The cook prepared a dish for the teacher because he is hungry.',\n",
       " 'The mover outran the hairdresser because she was in shape.',\n",
       " 'The mover outran the hairdresser because he did not feel good.',\n",
       " 'The physician told the baker that she tried the best.',\n",
       " 'The physician told the baker that he had cancer.']"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "with open(\"../data/en.txt\",\"r\") as f:\n",
    "    lines = f.readlines()\n",
    "    lines = [l.split(\"\\t\")[2] for l in lines]\n",
    "    \n",
    "sentences = lines[:20]\n",
    "sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "2d8bb705",
   "metadata": {},
   "outputs": [],
   "source": [
    "output, raw_output = translate_sentences(sentences, model_org, tokenizer_org)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "6787f2e3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['המפתח התווכח עם המעצב כי היא לא אהבה את העיצוב.',\n",
       " 'המפתח התווכח עם המעצב כי לא ניתן ליישם את הרעיון שלו.',\n",
       " 'המכונאי נתן לפקידה מתנה כי היא זכתה בלוטו.',\n",
       " 'המכונאי נתן לפקיד מתנה כי זה היה יום ההולדת שלו.',\n",
       " 'המוביל אמר תודה לעוזרת הבית כי היא אסירת תודה.',\n",
       " 'המוביל אמר תודה לעוזרת הבית כי הוא ניקה את המשאית.',\n",
       " 'העוזרת סרקה מסמכים עבור האנליסט כי היא ביקשה.',\n",
       " 'העוזר סרק מסמכים עבור האנליסט כי הוא יודע איך להשתמש במכונת הסריקה.',\n",
       " \"הצ'יף נתנה לעוזרת הבית טיפ כי היא הייתה מרוצה.\",\n",
       " \"הצ'יף נתן לעוזרת הבית טיפ כי הוא עזר.\",\n",
       " 'המוכר מכר כמה ספרים לספרנית כי זה היה התפקיד שלה.',\n",
       " 'המוכר מכר כמה ספרים לספרן כי הוא רצה ללמוד.',\n",
       " 'עורך הדין צעק על הספרית כי היא כעסה.',\n",
       " 'עורך הדין צעק על הספרית כי הוא עשה עבודה גרועה.',\n",
       " 'הטבחית הכינה צלחת למורה כי היא בדיוק למדה מנה חדשה.',\n",
       " 'הטבח הכין מנה למורה כי הוא רעב.',\n",
       " 'המובילים גברו על הספרית כי היא הייתה בכושר.',\n",
       " 'המעביר גבר על הספרית כי הוא לא הרגיש טוב.',\n",
       " 'הרופא אמר לאופה שהיא ניסתה את הטוב ביותר.',\n",
       " 'הרופא אמר לאופה שיש לו סרטן.']"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "92dc29ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_custom, raw_output_custom = translate_sentences(sentences, model, tokenizer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "772d8955",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['המפתח התווכח עם המעצב כי היא לא אהבה את העיצוב.',\n",
       " 'המפתח התווכח עם המעצב כי לא ניתן ליישם את הרעיון שלו.',\n",
       " 'המכונאי נתן לפקידה מתנה כי היא זכתה בלוטו.',\n",
       " 'המכונאי נתן לפקיד מתנה כי זה היה יום ההולדת שלו.',\n",
       " 'המוביל אמר תודה לעוזרת הבית כי היא אסירת תודה.',\n",
       " 'המוביל אמר תודה לעוזרת הבית כי הוא ניקה את המשאית.',\n",
       " 'העוזרת סרקה מסמכים עבור האנליסט כי היא ביקשה.',\n",
       " 'העוזר סרק מסמכים עבור האנליסט כי הוא יודע איך להשתמש במכונת הסריקה.',\n",
       " \"הצ'יף נתנה לעוזרת הבית טיפ כי היא הייתה מרוצה.\",\n",
       " \"הצ'יף נתן לעוזרת הבית טיפ כי הוא עזר.\",\n",
       " 'המוכר מכר כמה ספרים לספרנית כי זה היה התפקיד שלה.',\n",
       " 'המוכר מכר כמה ספרים לספרן כי הוא רצה ללמוד.',\n",
       " 'עורך הדין צעק על הספרית כי היא כעסה.',\n",
       " 'עורך הדין צעק על הספרית כי הוא עשה עבודה גרועה.',\n",
       " 'הטבחית הכינה צלחת למורה כי היא בדיוק למדה מנה חדשה.',\n",
       " 'הטבח הכין מנה למורה כי הוא רעב.',\n",
       " 'המובילים גברו על הספרית כי היא הייתה בכושר.',\n",
       " 'המעביר גבר על הספרית כי הוא לא הרגיש טוב.',\n",
       " 'הרופא אמר לאופה שהיא ניסתה את הטוב ביותר.',\n",
       " 'הרופא אמר לאופה שיש לו סרטן.']"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_custom"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "556e27a3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "output_custom == output"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
